{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GTEx MatrixTables"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To create MatrixTables containing all variant-gene associations tested in each tissue (including non-significant associations) for [GTEx](https://gtexportal.org/home/datasets) v8.\n",
    "\n",
    "There are two MatrixTables, one is for the eQTL tissue-specific all SNP gene associations data and the other is for the sQTL tissue-specific all SNP gene associations data. \n",
    "\n",
    "Hail Tables for each tissue were already created previously from the data [here](https://console.cloud.google.com/storage/browser/hail-datasets-tmp/GTEx/GTEx_Analysis_v8_QTLs). For eQTL each table is ~7 GiB, and for sQTL each table is ~40 GiB or so. A README describing the fields in the GTEx QTL datasets is available [here](https://storage.googleapis.com/gtex_analysis_v8/single_tissue_qtl_data/README_eQTL_v8.txt).\n",
    "\n",
    "Each MatrixTable has rows keyed by `[\"locus\", \"alleles\"]`, and columns keyed by `[\"tissue\"]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import hail as hl\n",
    "hl.init()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we can grab a list of the GTEx tissue names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list_tissues = subprocess.run([\"gsutil\", \"-u\", \"broad-ctsa\", \"ls\", \n",
    "                                \"gs://hail-datasets-tmp/GTEx/GTEx_Analysis_v8_QTLs/GTEx_Analysis_v8_eQTL_all_associations\"], \n",
    "                               stdout=subprocess.PIPE)\n",
    "tissue_files = list_tissues.stdout.decode(\"utf-8\").split()\n",
    "tissue_names = [x.split(\"/\")[-1].split(\".\")[0] for x in tissue_files]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Take a peek at the tissue names we get to make sure they're what we expect:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tissue_names[0:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can start with the process for the eQTL tables since they are smaller and a bit easier to work with. There are pretty much three steps here\n",
    "  - Generate individual MatrixTables from the existing Hail Tables for each tissue type, there are 49 tissue types in total.\n",
    "  - Perform a multi-way union cols (MWUC) on these 49 MatrixTables to create a single MatrixTable where there is a column for each tissue.\n",
    "  - After the MWUC the resulting MatrixTable has pretty imbalanced partitions (some are KiBs, others are GiBs) so we have to repartition the unioned MatrixTable."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### eQTL tissue-specific all SNP gene associations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generate individual MatrixTables from the existing Hail Tables for each tissue type (49 total).\n",
    "\n",
    "Write output to `gs://hail-datasets-tmp/GTEx/eQTL_MatrixTables/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for tissue_name in tissue_names:\n",
    "    print(f\"eQTL: {tissue_name}\")\n",
    "    ht = hl.read_table(f\"gs://hail-datasets-us/GTEx_eQTL_allpairs_{tissue_name}_v8_GRCh38.ht\", _n_partitions=64)\n",
    "\n",
    "    ht = ht.annotate(_gene_id = ht.gene_id, _tss_distance = ht.tss_distance)\n",
    "    ht = ht.drop(\"variant_id\", \"metadata\")\n",
    "    ht = ht.key_by(\"locus\", \"alleles\", \"_gene_id\", \"_tss_distance\")\n",
    "    ht = ht.annotate(**{tissue_name: ht.row_value.drop(\"gene_id\", \"tss_distance\")})\n",
    "    ht = ht.select(tissue_name)\n",
    "\n",
    "    mt = ht.to_matrix_table_row_major(columns=[tissue_name], col_field_name=\"tissue\")\n",
    "    mt = mt.checkpoint(\n",
    "        f\"gs://hail-datasets-tmp/GTEx/eQTL_MatrixTables/GTEx_eQTL_all_snp_gene_associations_{tissue_name}_v8_GRCh38.mt\", \n",
    "        overwrite=False,\n",
    "        _read_if_exists=True\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To ensure that everything is joined correctly later on, we add both the `_gene_id` and `tss_distance` fields to the table keys here. \n",
    "\n",
    "After the unioned MatrixTable is created we will re-key the rows to just be `[\"locus\", \"alleles\"]`, and rename the fields above back to `gene_id` and `tss_distance` (they will now be row fields)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Perform multi-way union cols (MWUC) on MatrixTables generated above"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function below was used to take a list of MatrixTables and a list with the column key fields and output a single MatrixTable with the columns unioned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "def multi_way_union_cols(mts: List[hl.MatrixTable], column_keys: List[str]) -> hl.MatrixTable:\n",
    "    missing_struct = \"struct{ma_samples: int32, ma_count: int32, maf: float64, pval_nominal: float64, slope: float64, slope_se: float64}\"\n",
    "    \n",
    "    mts = [mt._localize_entries(\"_mt_entries\", \"_mt_cols\") for mt in mts]\n",
    "    \n",
    "    joined = hl.Table.multi_way_zip_join(mts, \"_t_entries\", \"_t_cols\")\n",
    "    joined = joined.annotate(_t_entries_missing = joined._t_entries.map(lambda x: hl.is_missing(x)))\n",
    "    \n",
    "    rows = [(r, joined._t_entries.map(lambda x: x[r])[0])\n",
    "            for r in joined._t_entries.dtype.element_type.fields \n",
    "            if r != \"_mt_entries\"]\n",
    "    \"\"\"\n",
    "    Need to provide a dummy array<struct> for if tissues are not present to make sure missing elements not\n",
    "    dropped from flattened array. \n",
    "    \n",
    "    Otherwise we will get a HailException: length mismatch between entry array and column array in \n",
    "    'to_matrix_table_row_major'.\n",
    "    \"\"\"\n",
    "    entries = [(\"_t_entries_flatten\", \n",
    "                hl.flatten(\n",
    "                    joined._t_entries.map(\n",
    "                        lambda x: hl.if_else(\n",
    "                            hl.is_defined(x), \n",
    "                            x._mt_entries,\n",
    "                            hl.array([\n",
    "                                hl.struct(\n",
    "                                    ma_samples = hl.missing(hl.tint32), \n",
    "                                    ma_count = hl.missing(hl.tint32), \n",
    "                                    maf = hl.missing(hl.tfloat64), \n",
    "                                    pval_nominal = hl.missing(hl.tfloat64), \n",
    "                                    slope = hl.missing(hl.tfloat64), \n",
    "                                    slope_se = hl.missing(hl.tfloat64)\n",
    "                                )\n",
    "                            ])\n",
    "                        )\n",
    "                    )\n",
    "                )\n",
    "               )]\n",
    "    joined = joined.annotate(**dict(rows + entries))\n",
    "    \"\"\"\n",
    "    Also want to make sure that if entry is missing, it is replaced with a missing struct of the same form\n",
    "    at the same index in the array.\n",
    "    \"\"\"\n",
    "    joined = joined.annotate(_t_entries_new = hl.zip(joined._t_entries_missing, \n",
    "                                                     joined._t_entries_flatten, \n",
    "                                                     fill_missing=False))\n",
    "    joined = joined.annotate(\n",
    "        _t_entries_new = joined._t_entries_new.map(\n",
    "            lambda x: hl.if_else(x[0] == True, hl.missing(missing_struct), x[1])\n",
    "        )\n",
    "    )    \n",
    "    joined = joined.annotate_globals(_t_cols = hl.flatten(joined._t_cols.map(lambda x: x._mt_cols)))\n",
    "    joined = joined.drop(\"_t_entries\", \"_t_entries_missing\", \"_t_entries_flatten\")\n",
    "    mt = joined._unlocalize_entries(\"_t_entries_new\", \"_t_cols\", [\"tissue\"])\n",
    "    return mt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can read in each individual MatrixTable and add it to the list we will pass to `multi_way_union_cols`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get list of file paths for individual eQTL MatrixTables\n",
    "list_eqtl_mts = subprocess.run([\"gsutil\", \"-u\", \"broad-ctsa\", \"ls\", \"gs://hail-datasets-tmp/GTEx/eQTL_MatrixTables\"], \n",
    "                               stdout=subprocess.PIPE)\n",
    "eqtl_mts = list_eqtl_mts.stdout.decode(\"utf-8\").split()\n",
    "\n",
    "# Load MatrixTables for each tissue type to store in list for MWUC\n",
    "mts_list = []\n",
    "for eqtl_mt in eqtl_mts:\n",
    "    tissue_name = eqtl_mt.replace(\"gs://hail-datasets-tmp/GTEx/eQTL_MatrixTables/GTEx_eQTL_all_snp_gene_associations_\", \"\")\n",
    "    tissue_name = tissue_name.replace(\"_v8_GRCh38.mt/\", \"\")\n",
    "    print(tissue_name)\n",
    "    \n",
    "    mt = hl.read_matrix_table(eqtl_mt)\n",
    "    mts_list.append(mt)\n",
    "\n",
    "full_mt = multi_way_union_cols(mts_list, [\"tissue\"])\n",
    "full_mt = full_mt.checkpoint(\"gs://hail-datasets-tmp/GTEx/checkpoints/GTEx_eQTL_all_snp_gene_associations_cols_unioned.mt\", \n",
    "                             overwrite=False,\n",
    "                             _read_if_exists=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Repartition unioned MatrixTable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the MWUC the resulting MatrixTable has pretty imbalanced partitions (some are KiBs, others are GiBs) so we want to repartition the unioned MatrixTable. \n",
    "\n",
    "First we can re-key the rows of our MatrixTable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-key rows and repartition\n",
    "full_mt = hl.read_matrix_table(\"gs://hail-datasets-tmp/GTEx/checkpoints/GTEx_eQTL_all_snp_gene_associations_cols_unioned.mt\", \n",
    "                               _n_partitions=1000)\n",
    "full_mt = full_mt.key_rows_by(\"locus\", \"alleles\")\n",
    "full_mt = full_mt.checkpoint(\"gs://hail-datasets-tmp/GTEx/GTEx_eQTL_all_snp_gene_associations.mt\", \n",
    "                             overwrite=False, \n",
    "                             _read_if_exists=True)\n",
    "full_mt.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I tried reading in the MatrixTable with `_n_partitions=1000` to see how our partitions would look, but we still had a few that were much larger than the rest. So after this I ended up doing using `repartition` with a full shuffle, and it balanced things out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add metadata to globals and write final MatrixTable to hail-datasets-us\n",
    "full_mt = hl.read_matrix_table(\"gs://hail-datasets-tmp/GTEx/GTEx_eQTL_all_snp_gene_associations.mt\")\n",
    "full_mt = full_mt.repartition(1000, shuffle=True)\n",
    "\n",
    "n_rows, n_cols = full_mt.count()\n",
    "n_partitions = full_mt.n_partitions()\n",
    "\n",
    "full_mt = full_mt.rename({\"_gene_id\": \"gene_id\", \"_tss_distance\": \"tss_distance\"})\n",
    "full_mt = full_mt.annotate_globals(\n",
    "    metadata = hl.struct(name = \"GTEx_eQTL_all_snp_gene_associations\",\n",
    "                         reference_genome = \"GRCh38\",\n",
    "                         n_rows = n_rows,\n",
    "                         n_cols = n_cols,\n",
    "                         n_partitions = n_partitions)\n",
    ")\n",
    "# Final eQTL MatrixTable is ~224 GiB w/ 1000 partitions\n",
    "full_mt.write(\"gs://hail-datasets-us/GTEx_eQTL_all_snp_gene_associations_v8_GRCh38.mt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we have a single MatrixTable for the GTEx eQTL data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hl.read_matrix_table(\"gs://hail-datasets-us/GTEx_eQTL_all_snp_gene_associations_v8_GRCh38.mt\").describe()"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "pycharm": {
     "name": "#%% raw\n"
    }
   },
   "source": [
    "----------------------------------------\n",
    "Global fields:\n",
    "    'metadata': struct {\n",
    "        name: str, \n",
    "        reference_genome: str, \n",
    "        n_rows: int32, \n",
    "        n_cols: int32, \n",
    "        n_partitions: int32\n",
    "    }\n",
    "----------------------------------------\n",
    "Column fields:\n",
    "    'tissue': str\n",
    "----------------------------------------\n",
    "Row fields:\n",
    "    'locus': locus<GRCh38>\n",
    "    'alleles': array<str>\n",
    "    'gene_id': str\n",
    "    'tss_distance': int32\n",
    "----------------------------------------\n",
    "Entry fields:\n",
    "    'ma_samples': int32\n",
    "    'ma_count': int32\n",
    "    'maf': float64\n",
    "    'pval_nominal': float64\n",
    "    'slope': float64\n",
    "    'slope_se': float64\n",
    "----------------------------------------\n",
    "Column key: ['tissue']\n",
    "Row key: ['locus', 'alleles']\n",
    "----------------------------------------\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}